Survival Analysis: CatBoost

Author

R. Jerome Dixon

Published

December 3, 2024

Demographics Table

| Trait | Transplanted | WL Mortality | P-value |
|---|---|---|---|
| Gender (% Female) | 44.1% (N=4983) | 43.1% (N=540) | 0.709 |
| Age (years) | 4.0 [0.0, 10.5] (N=4983) | 0.0 [0.0, 1.6] (N=540) | <0.001 |
| Weight | 15.0 [0.0, 33.5] (N=4982) | 7.6 [2.4, 12.7] (N=540) | <0.001 |
| Height | 99.0 [55.0, 143.0] (N=4980) | 67.0 [45.6, 88.4] (N=540) | <0.001 |
| BMI | 16.2 [13.7, 18.6] (N=4980) | 15.4 [13.6, 17.2] (N=540) | <0.001 |
| Blood Type A | 36.0% (N=4983) | 33.9% (N=540) | 0.345 |
| Blood Type B | 13.6% (N=4983) | 11.5% (N=540) | 0.190 |
| Blood Type AB | 4.1% (N=4983) | 2.2% (N=540) | 0.042 |
| Blood Type O | 46.2% (N=4983) | 52.4% (N=540) | 0.007 |
| Race White | 53.0% (N=4983) | 50.4% (N=540) | 0.268 |
| Race Black | 20.0% (N=4983) | 23.0% (N=540) | 0.123 |
| Race Other | 3.1% (N=4983) | 3.0% (N=540) | 1.000 |
| VAD % | 13.3% (N=4983) | 8.5% (N=540) | 0.002 |
| eGFR | 96.1 [72.3, 120.0] (N=4966) | 82.2 [54.2, 110.2] (N=540) | <0.001 |
| Albumin | 3.6 [3.1, 4.1] (N=4836) | 3.3 [2.8, 3.8] (N=518) | <0.001 |
| Dialysis % | 1.2% (N=4978) | 4.8% (N=540) | <0.001 |
| Ventilator % | 15.8% (N=4983) | 35.0% (N=540) | <0.001 |
| ECMO % | 4.6% (N=4983) | 14.6% (N=540) | <0.001 |
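The table does not state which tests produced the P-values. For the percentage rows, a pooled two-proportion z-test is one common choice; the sketch below (standard library only, with the hypothetical helper name `two_proportion_p_value`) applies it to the Gender row. It need not reproduce the reported 0.709 exactly: a chi-square test with continuity correction, for instance, gives a larger p-value for the same counts.

```python
import math


def two_proportion_p_value(p1: float, n1: int, p2: float, n2: int) -> float:
    """Two-sided p-value of a pooled two-proportion z-test."""
    # Pooled proportion under the null hypothesis p1 == p2
    pooled = (p1 * n1 + p2 * n2) / (n1 + n2)
    se = math.sqrt(pooled * (1 - pooled) * (1 / n1 + 1 / n2))
    z = (p1 - p2) / se
    # erfc(|z| / sqrt(2)) equals 2 * (1 - Phi(|z|)), the two-sided normal tail
    return math.erfc(abs(z) / math.sqrt(2))


# Gender (% Female): 44.1% of 4983 transplanted vs 43.1% of 540 WL deaths
p_gender = two_proportion_p_value(0.441, 4983, 0.431, 540)
print(round(p_gender, 3))
```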

Native CatBoost Model

Optuna Hyperparameter Optimization

  • Best is trial #12/50 with value: 0.74
# Best hyperparameters from Optuna trial #12 (native categorical handling)
model_params_native = {
    'learning_rate': 0.15,
    'depth': 8,
    'colsample_bylevel': 0.75,
    'min_data_in_leaf': 95,
    'l2_leaf_reg': 10.54
}
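Only the best trial is reported, not the search space. A minimal sketch of a search space that could yield parameters like these, assuming Optuna's `Trial.suggest_float` and `Trial.suggest_int`; the ranges are illustrative assumptions chosen so both reported parameter sets fall inside them, and `MidpointTrial` is a hypothetical stand-in so the function can be exercised without running a study.

```python
from typing import Any, Dict


def suggest_catboost_params(trial: Any) -> Dict[str, Any]:
    """Sample CatBoost hyperparameters from an Optuna-style trial (assumed ranges)."""
    return {
        "learning_rate": trial.suggest_float("learning_rate", 0.01, 0.3, log=True),
        "depth": trial.suggest_int("depth", 3, 10),
        "colsample_bylevel": trial.suggest_float("colsample_bylevel", 0.05, 1.0),
        "min_data_in_leaf": trial.suggest_int("min_data_in_leaf", 1, 100),
        "l2_leaf_reg": trial.suggest_float("l2_leaf_reg", 1.0, 20.0, log=True),
    }


class MidpointTrial:
    """Hypothetical stand-in for an Optuna trial that returns range midpoints."""

    def suggest_float(self, name, low, high, log=False):
        return (low + high) / 2

    def suggest_int(self, name, low, high):
        return (low + high) // 2


example_params = suggest_catboost_params(MidpointTrial())
print(example_params)
```

In a real study, an objective function would pass `suggest_catboost_params(trial)` into `CatBoostClassifier`, fit it and return the validation AUC, and `optuna.create_study(direction="maximize").optimize(objective, n_trials=50)` would run the 50 trials referred to above.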
    

Model

from catboost import CatBoostClassifier, Pool

# X_train, X_test, Y_train, Y_test and cat_index (positions of the
# categorical columns) are created earlier in the analysis.

# Wrap the data in Pools so CatBoost encodes the categorical features natively
train_pool = Pool(X_train, cat_features=cat_index, label=Y_train)
test_pool = Pool(X_test, cat_features=cat_index, label=Y_test)

model_auc = CatBoostClassifier(objective='Logloss',
                               iterations=1000,
                               eval_metric="AUC",
                               **model_params_native,
                               boosting_type='Ordered',
                               bootstrap_type='MVS',
                               metric_period=25,           # log every 25 iterations
                               early_stopping_rounds=100,  # stop after 100 rounds without AUC gain
                               # keep every tree built before the stop,
                               # not only those up to bestIteration
                               use_best_model=False,
                               random_seed=1997)

model_auc.fit(train_pool, eval_set=test_pool)
Warning: Overfitting detector is active, thus evaluation metric is calculated on every iteration. 'metric_period' is ignored for evaluation metric.
0:  test: 0.6234443 best: 0.6234443 (0) total: 184ms    remaining: 3m 3s
25: test: 0.7215217 best: 0.7221167 (24)    total: 1.33s    remaining: 50s
50: test: 0.7234322 best: 0.7259924 (37)    total: 2.5s remaining: 46.5s
75: test: 0.7192891 best: 0.7259924 (37)    total: 3.95s    remaining: 48.1s
100:    test: 0.7195511 best: 0.7259924 (37)    total: 5.64s    remaining: 50.2s
125:    test: 0.7081423 best: 0.7259924 (37)    total: 7.33s    remaining: 50.8s
Stopped by overfitting detector  (100 iterations wait)

bestTest = 0.7259924014
bestIteration = 37

<catboost.core.CatBoostClassifier object at 0x0000028C8E6C3E00>
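With `early_stopping_rounds=100`, training halts once 100 iterations pass without improving on the best test AUC (iteration 37), and because `use_best_model=False` the fitted model keeps those later trees instead of being truncated at `bestIteration`. The patience rule can be sketched in plain Python; this is a simplified illustration with a made-up AUC history, not CatBoost's internal code.

```python
from typing import List, Sequence, Tuple


def early_stopping_point(scores: Sequence[float], patience: int) -> Tuple[int, int]:
    """Return (best_iteration, stop_iteration) for a higher-is-better metric.

    Training stops at the first iteration that is `patience` iterations
    past the best score seen so far; otherwise it runs to the end.
    """
    best_iter = 0
    for i, score in enumerate(scores):
        if score > scores[best_iter]:
            best_iter = i
        if i - best_iter >= patience:
            return best_iter, i
    return best_iter, len(scores) - 1


# Hypothetical AUC history: improves until iteration 37, then declines
toy_auc: List[float] = [0.60 + 0.003 * i for i in range(38)]
toy_auc += [toy_auc[-1] - 0.0005 * k for k in range(1, 300)]

print(early_stopping_point(toy_auc, patience=100))  # → (37, 137)
```

With a peak at iteration 37 and a patience of 100, the rule stops at iteration 137, consistent with the log above ending shortly after iteration 125.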

Calibration Plot

Calibrated Model Metrics

             Model       AUC  ...  False Negative (FN)  True Positive (TP)
0  Native_Catboost  0.707553  ...                   35                  67

[1 rows x 13 columns]

SHAP Feature Importance

library(reticulate)
library(dplyr)
library(tibble)

# Retrieve the SHAP importance dataframe from Python
shap_df <- py$shap_df

# Convert the pandas dataframe to an R tibble and rank by absolute SHAP importance
shap_tbl <- as_tibble(shap_df) %>%
  mutate(SHAP_Importance = abs(Importance_SHAP)) %>%
  arrange(desc(SHAP_Importance))

# Add the overall rank and keep the columns to display, including 'SHAP_Direction'
shap_table <- shap_tbl %>%
  rowid_to_column(var = "Overall Rank") %>%
  select(`Overall Rank`, `Feature Id`,
         Importance_SHAP, SHAP_Direction,
         SHAP_Importance)

shap_table %>% 
  DT::datatable(
    rownames = FALSE,
    options = list(
      columnDefs = list(
        list(className = 'dt-center', targets = "_all")
      )
    )
  )
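The table relies on `Importance_SHAP` and `SHAP_Direction` columns computed in Python, and how the direction is derived is not shown. A minimal sketch for numeric features, assuming importance is the mean absolute SHAP value and direction is the sign of the Pearson correlation between a feature's values and its SHAP values; the function name and toy data are hypothetical.

```python
import numpy as np
import pandas as pd


def shap_importance_table(shap_matrix: np.ndarray, X: pd.DataFrame) -> pd.DataFrame:
    """Rank numeric features by mean |SHAP| and label the direction of effect."""
    mean_abs = np.abs(shap_matrix).mean(axis=0)
    directions = []
    for j, col in enumerate(X.columns):
        x = X[col].to_numpy(dtype=float)
        s = shap_matrix[:, j]
        if np.std(x) == 0 or np.std(s) == 0:
            directions.append("none")  # correlation undefined for constant input
            continue
        r = np.corrcoef(x, s)[0, 1]
        directions.append("positive" if r > 0 else "negative" if r < 0 else "none")
    table = pd.DataFrame({
        "Feature Id": X.columns,
        "Mean_Abs_SHAP": mean_abs,
        "SHAP_Direction": directions,
    })
    return table.sort_values("Mean_Abs_SHAP", ascending=False).reset_index(drop=True)


# Hypothetical toy data: SHAP rises with `a` and falls with `b`
X_toy = pd.DataFrame({"a": [0.0, 1.0, 2.0, 3.0], "b": [3.0, 1.0, 2.0, 0.0]})
shap_toy = np.column_stack([X_toy["a"] - 1.5, -0.1 * (X_toy["b"] - 1.5)])
print(shap_importance_table(shap_toy, X_toy))
```

A correlation-based sign is only a summary: for features whose effect is non-monotonic, a dependence plot is more informative than a single direction label.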

CatBoost - One Hot Encoding (Hybrid)

  • Train
 [1] "Gender"     "Race"       "Blood_Type" "VAD_TCR"    "WL_Oth_Org"
 [6] "Cereb_Vasc" "Diabetes"   "Diag_Code"  "XMatch_Req" "List_Ctr"  
[1] "Race"       "Blood_Type" "Diag_Code" 
'data.frame':   4523 obs. of  44 variables:
 $ outcome                                           : int  0 0 0 1 0 1 0 1 1 1 ...
 $ Age                                               : num  1 0 0 0 2 0 1 11 0 0 ...
  ..- attr(*, "label")= chr "WL AGE AT LISTING IN YEARS"
 $ Gender                                            : Factor w/ 2 levels "F","M": 1 2 2 2 1 2 1 1 1 2 ...
 $ Weight                                            : num  6.71 3.02 6.6 6 19.4 ...
 $ Height                                            : num  72 49.5 66 57 97.5 ...
 $ BMI                                               : num  12.9 12.3 15.2 18.5 20.4 ...
 $ BSA                                               : num  0.366 0.204 0.348 0.308 0.725 ...
 $ PGE_TCR                                           : num  0 0 0 0 0 0 0 0 0 0 ...
  ..- attr(*, "label")= chr "TCR PGE AT LISTING"
 $ ECMO_Reg                                          : num  0 0 0 1 0 0 0 0 0 0 ...
  ..- attr(*, "label")= chr "TCR ECMO AT LISTING"
 $ VAD                                               : num  -1 -1 -1 -1 -1 -1 -1 1 -1 1 ...
 $ VAD_TCR                                           : Factor w/ 5 levels "LVAD","LVAD+RVAD",..: 3 3 3 3 3 3 3 1 3 2 ...
 $ Vent_Reg                                          : num  0 0 0 0 0 0 0 0 0 1 ...
  ..- attr(*, "label")= chr "TCR LIFE SUPPORT VENTILATOR"
 $ WL_Oth_Org                                        : Factor w/ 2 levels "No","Yes": 2 1 1 2 1 1 1 1 1 1 ...
 $ Cereb_Vasc                                        : Factor w/ 3 levels "No","Unknown",..: 1 1 1 1 1 1 1 1 1 1 ...
 $ Diabetes                                          : Factor w/ 5 levels "None","Type I",..: 1 1 1 1 1 1 1 1 1 1 ...
 $ Dialysis                                          : num  -1 -1 -1 1 -1 -1 -1 -1 -1 -1 ...
 $ Inotrop                                           : num  0 0 0 0 1 0 1 0 1 0 ...
  ..- attr(*, "label")= chr "TCR IV INOTROPES AT LISTING"
 $ Creatinine                                        : num  0.13 0.46 0.28 0.67 0.42 0.34 0.34 0.5 0.22 0.21 ...
  ..- attr(*, "label")= chr "TCR MOST RECENT CREAT."
 $ eGFR                                              : num  228.2 44.3 97.1 35.1 95.6 ...
  ..- attr(*, "label")= chr "TCR MOST RECENT CREAT."
 $ Albumin                                           : num  4.4 3.7 3 3.8 4.6 3.9 4.6 3.6 3.2 3.2 ...
  ..- attr(*, "label")= chr "TCR TOTAL SERUM ALBUMIN AT LISTING (Pre 1/1/2007 for adult)"
 $ Prior_HRTX                                        : num  15 2 13 18 13 15 30 11 27 16 ...
 $ Med_Refusals                                      : num  3 8 4 15 5 3 5 5 4 4 ...
 $ Prop_Refusals                                     : num  0.864 0.951 0.852 0.972 0.882 ...
 $ XMatch_Req                                        : Factor w/ 2 levels "No","Yes": 1 1 1 1 1 1 1 1 1 1 ...
 $ List_Yr                                           : num  2020 2020 2020 2020 2020 2020 2020 2020 2020 2020 ...
 $ Policy_Chg                                        : num  1 1 1 1 1 1 1 1 1 1 ...
 $ List_Ctr                                          : Factor w/ 89 levels "ahbent","alfoba",..: 61 34 37 69 16 61 21 71 66 36 ...
 $ Race.Asian                                        : num  0 0 0 0 0 0 0 0 0 0 ...
 $ Race.Black                                        : num  0 0 1 0 0 0 0 0 0 0 ...
 $ Race.Hispanic                                     : num  0 0 0 0 0 1 0 0 0 0 ...
 $ Race.Other                                        : num  0 0 0 0 0 0 0 1 0 0 ...
 $ Race.White                                        : num  1 1 0 1 1 0 1 0 1 1 ...
 $ Blood_Type.A                                      : num  0 0 0 1 0 0 0 0 1 0 ...
 $ Blood_Type.AB                                     : num  0 0 0 0 0 0 0 0 0 0 ...
 $ Blood_Type.B                                      : num  0 0 0 0 0 0 0 0 0 0 ...
 $ Blood_Type.O                                      : num  1 1 1 0 1 1 1 1 0 1 ...
 $ Diag_Code.Congenital Heart Disease With Surgery   : num  1 1 1 1 1 0 0 0 0 0 ...
 $ Diag_Code.Congenital Heart Disease Without Surgery: num  0 0 0 0 0 0 0 0 0 0 ...
 $ Diag_Code.Dilated Cardiomyopathy                  : num  0 0 0 0 0 1 1 1 0 1 ...
 $ Diag_Code.Hypertrophic Cardiomyopathy             : num  0 0 0 0 0 0 0 0 1 0 ...
 $ Diag_Code.Myocarditis                             : num  0 0 0 0 0 0 0 0 0 0 ...
 $ Diag_Code.Other                                   : num  0 0 0 0 0 0 0 0 0 0 ...
 $ Diag_Code.Restrictive Cardiomyopathy              : num  0 0 0 0 0 0 0 0 0 0 ...
 $ Diag_Code.Valvular Heart Disease                  : num  0 0 0 0 0 0 0 0 0 0 ...
'data.frame':   4523 obs. of  44 variables:
 $ outcome      : int  0 0 0 1 0 1 0 1 1 1 ...
 $ Age          : num  1 0 0 0 2 0 1 11 0 0 ...
  ..- attr(*, "label")= chr "WL AGE AT LISTING IN YEARS"
 $ Gender       : Factor w/ 2 levels "F","M": 1 2 2 2 1 2 1 1 1 2 ...
 $ Weight       : num  6.71 3.02 6.6 6 19.4 ...
 $ Height       : num  72 49.5 66 57 97.5 ...
 $ BMI          : num  12.9 12.3 15.2 18.5 20.4 ...
 $ BSA          : num  0.366 0.204 0.348 0.308 0.725 ...
 $ PGE_TCR      : num  0 0 0 0 0 0 0 0 0 0 ...
  ..- attr(*, "label")= chr "TCR PGE AT LISTING"
 $ ECMO_Reg     : num  0 0 0 1 0 0 0 0 0 0 ...
  ..- attr(*, "label")= chr "TCR ECMO AT LISTING"
 $ VAD          : num  -1 -1 -1 -1 -1 -1 -1 1 -1 1 ...
 $ VAD_TCR      : Factor w/ 5 levels "LVAD","LVAD+RVAD",..: 3 3 3 3 3 3 3 1 3 2 ...
 $ Ventilator   : num  0 0 0 0 0 0 0 0 0 1 ...
  ..- attr(*, "label")= chr "TCR LIFE SUPPORT VENTILATOR"
 $ WL_Oth_Org   : Factor w/ 2 levels "No","Yes": 2 1 1 2 1 1 1 1 1 1 ...
 $ Cereb_Vasc   : Factor w/ 3 levels "No","Unknown",..: 1 1 1 1 1 1 1 1 1 1 ...
 $ Diabetes     : Factor w/ 5 levels "None","Type I",..: 1 1 1 1 1 1 1 1 1 1 ...
 $ Dialysis     : num  -1 -1 -1 1 -1 -1 -1 -1 -1 -1 ...
 $ Inotrop      : num  0 0 0 0 1 0 1 0 1 0 ...
  ..- attr(*, "label")= chr "TCR IV INOTROPES AT LISTING"
 $ Creatinine   : num  0.13 0.46 0.28 0.67 0.42 0.34 0.34 0.5 0.22 0.21 ...
  ..- attr(*, "label")= chr "TCR MOST RECENT CREAT."
 $ eGFR         : num  228.2 44.3 97.1 35.1 95.6 ...
  ..- attr(*, "label")= chr "TCR MOST RECENT CREAT."
 $ Albumin      : num  4.4 3.7 3 3.8 4.6 3.9 4.6 3.6 3.2 3.2 ...
  ..- attr(*, "label")= chr "TCR TOTAL SERUM ALBUMIN AT LISTING (Pre 1/1/2007 for adult)"
 $ Txp_Volume   : num  15 2 13 18 13 15 30 11 27 16 ...
 $ Med_Refusals : num  3 8 4 15 5 3 5 5 4 4 ...
 $ Prop_Refusals: num  0.864 0.951 0.852 0.972 0.882 ...
 $ XMatch_Req   : Factor w/ 2 levels "No","Yes": 1 1 1 1 1 1 1 1 1 1 ...
 $ List_Yr      : num  2020 2020 2020 2020 2020 2020 2020 2020 2020 2020 ...
 $ Policy_Chg   : num  1 1 1 1 1 1 1 1 1 1 ...
 $ Listing_Ctr  : Factor w/ 89 levels "ahbent","alfoba",..: 61 34 37 69 16 61 21 71 66 36 ...
 $ Race_Asian   : num  0 0 0 0 0 0 0 0 0 0 ...
 $ Race_Black   : num  0 0 1 0 0 0 0 0 0 0 ...
 $ Race_Hispanic: num  0 0 0 0 0 1 0 0 0 0 ...
 $ Race_Other   : num  0 0 0 0 0 0 0 1 0 0 ...
 $ Race_White   : num  1 1 0 1 1 0 1 0 1 1 ...
 $ Blood_Type_A : num  0 0 0 1 0 0 0 0 1 0 ...
 $ Blood_Type_AB: num  0 0 0 0 0 0 0 0 0 0 ...
 $ Blood_Type_B : num  0 0 0 0 0 0 0 0 0 0 ...
 $ Blood_Type_O : num  1 1 1 0 1 1 1 1 0 1 ...
 $ CHD_Surgery  : num  1 1 1 1 1 0 0 0 0 0 ...
 $ CHD_NoSurgery: num  0 0 0 0 0 0 0 0 0 0 ...
 $ DCM          : num  0 0 0 0 0 1 1 1 0 1 ...
 $ HCM          : num  0 0 0 0 0 0 0 0 1 0 ...
 $ Myocard      : num  0 0 0 0 0 0 0 0 0 0 ...
 $ Other_Diag   : num  0 0 0 0 0 0 0 0 0 0 ...
 $ RCM          : num  0 0 0 0 0 0 0 0 0 0 ...
 $ VHD          : num  0 0 0 0 0 0 0 0 0 0 ...
  • Test
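The output above shows the hybrid scheme: only `Race`, `Blood_Type` and `Diag_Code` are expanded into indicator columns, while the remaining factors (including `List_Ctr` with 89 levels) stay categorical for CatBoost to encode natively. The encoding here was done in R; a rough pandas equivalent, assuming `pd.get_dummies` and toy rows (the helper name and data are hypothetical), looks like this:

```python
import pandas as pd

# Low-cardinality factors one-hot encoded in the hybrid setup;
# every other categorical column is left for CatBoost to handle natively.
ONE_HOT_COLS = ["Race", "Blood_Type", "Diag_Code"]


def hybrid_encode(train: pd.DataFrame, test: pd.DataFrame, one_hot_cols=ONE_HOT_COLS):
    """One-hot encode `one_hot_cols`; align test dummies to the training columns."""
    train_enc = pd.get_dummies(train, columns=list(one_hot_cols), prefix_sep="_", dtype=float)
    test_enc = pd.get_dummies(test, columns=list(one_hot_cols), prefix_sep="_", dtype=float)
    # Levels unseen in training are dropped; levels missing from test become 0
    test_enc = test_enc.reindex(columns=train_enc.columns, fill_value=0.0)
    return train_enc, test_enc


# Hypothetical toy rows
train_toy = pd.DataFrame({
    "Race": ["White", "Black", "White"],
    "Blood_Type": ["O", "A", "O"],
    "Diag_Code": ["DCM", "HCM", "DCM"],
    "List_Ctr": ["c1", "c2", "c1"],
    "Age": [1.0, 0.0, 11.0],
})
test_toy = pd.DataFrame({
    "Race": ["White"], "Blood_Type": ["B"], "Diag_Code": ["DCM"],
    "List_Ctr": ["c3"], "Age": [2.0],
})
train_enc, test_enc = hybrid_encode(train_toy, test_toy)
print(train_enc.columns.tolist())
```

Reindexing the test frame to the training columns keeps feature names identical across the two sets; a level seen only in the test data (here blood type B) is dropped rather than creating a column the model never saw.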

Hybrid CatBoost Model

import numpy as np

# Pull the hybrid (partially one-hot encoded) train and test sets from R
# through reticulate's `r` object
hybrid_X_train = r.hybrid_train_data
hybrid_y_train = r.hybrid_train_Y
hybrid_Y_train = np.array(hybrid_y_train)

hybrid_X_test = r.hybrid_test_data
hybrid_y_test = r.hybrid_test_Y
hybrid_Y_test = np.array(hybrid_y_test)

# Positions of the categorical columns left for CatBoost to encode natively
hybrid_cat_index = get_categorical_indexes(hybrid_X_train)
Feature names are consistent between training and test datasets.
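`get_categorical_indexes` and the check that prints the message above are defined elsewhere in the analysis and not shown. One plausible version, written under the hypothetical names `find_categorical_indexes` and `feature_names_match`, treats every non-numeric column as categorical, assuming R factors arrive in pandas as `category` (or string/object) columns:

```python
import pandas as pd


def find_categorical_indexes(X: pd.DataFrame) -> list:
    """Positions of non-numeric columns, treated as CatBoost categorical features."""
    return [i for i, dtype in enumerate(X.dtypes)
            if not pd.api.types.is_numeric_dtype(dtype)]


def feature_names_match(X_train: pd.DataFrame, X_test: pd.DataFrame) -> bool:
    """True when train and test share the same column names in the same order."""
    match = list(X_train.columns) == list(X_test.columns)
    if match:
        print("Feature names are consistent between training and test datasets.")
    return match


# Hypothetical toy frame: one numeric, one categorical, one string column
X_demo = pd.DataFrame({
    "Age": [1.0, 0.0],
    "Gender": pd.Categorical(["F", "M"]),
    "List_Ctr": pd.Series(["c1", "c2"], dtype=object),
})
print(find_categorical_indexes(X_demo))  # → [1, 2]
```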

Optuna Hyperparameter Optimization

Model

  • From Optuna Trial #36/50 with value: 0.74
# Best hyperparameters from Optuna trial #36 (hybrid one-hot/native encoding)
model_params_hybrid = {
    'learning_rate': 0.03,
    'depth': 4,
    'colsample_bylevel': 0.12,
    'min_data_in_leaf': 44,
    'l2_leaf_reg': 5.66
}
    
Warning: Overfitting detector is active, thus evaluation metric is calculated on every iteration. 'metric_period' is ignored for evaluation metric.
0:  test: 0.5891032 best: 0.5891032 (0) total: 25.6ms   remaining: 25.5s
Stopped by overfitting detector  (50 iterations wait)

bestTest = 0.7386785449
bestIteration = 255

<catboost.core.CatBoostClassifier object at 0x0000028C8E6F8560>

Calibration Plot

import pandas as pd

# Predicted classes and class probabilities on the test set
Y_Pred_hybrid = hybrid_model.predict(hybrid_X_test)
proba_hybrid = hybrid_model.predict_proba(hybrid_X_test)  # columns: P(class 0), P(class 1)
Y_Pred_Proba_Negative_hybrid = proba_hybrid[:, 0]  # probability of the negative class
Y_Pred_Proba_Positive_hybrid = proba_hybrid[:, 1]  # probability of the positive class
Y_Pred_Proba_hybrid = Y_Pred_Proba_Positive_hybrid

# Collect probabilities, predicted classes and actual outcomes in one DataFrame
hybrid_predictions = pd.DataFrame({
    'Prob_Negative_Class': Y_Pred_Proba_Negative_hybrid,
    'Prob_Positive_Class': Y_Pred_Proba_Positive_hybrid,
    'Predicted': Y_Pred_hybrid,
    'Actual': hybrid_y_test
})
library(dplyr)
library(probably)

# Label outcomes and name the event-probability column
hybrid_predictions <- py$hybrid_predictions %>%
  mutate(Class = ifelse(Actual == 0, "survive", "not_survive"),
         .pred_not_survive = Prob_Positive_Class)

# Reverse the levels so "not_survive" comes first and is treated as the event
factor_levels <- c("survive", "not_survive")
hybrid_predictions$Class <- factor(hybrid_predictions$Class, levels = rev(factor_levels))

# Logistic calibration curve for the predicted mortality probability
hybrid_predictions %>%
  cal_plot_logistic(Class, .pred_not_survive)
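`cal_plot_logistic()` fits a smooth logistic calibration curve. A complementary check is a binned reliability table that compares the mean predicted probability with the observed event rate inside each probability bin. A minimal NumPy/pandas sketch, assuming equal-width bins on [0, 1] (the function name `reliability_table` is hypothetical):

```python
import numpy as np
import pandas as pd


def reliability_table(y_true, prob, n_bins: int = 10) -> pd.DataFrame:
    """Binned calibration table: mean predicted probability vs observed event rate.

    Uses equal-width bins on [0, 1]; bins with no predictions are omitted.
    """
    y_true = np.asarray(y_true, dtype=float)
    prob = np.asarray(prob, dtype=float)
    # Bin index 0..n_bins-1; a probability of exactly 1.0 goes in the last bin
    bins = np.minimum((prob * n_bins).astype(int), n_bins - 1)
    df = pd.DataFrame({"bin": bins, "y": y_true, "p": prob})
    return (df.groupby("bin")
              .agg(n=("y", "size"), mean_predicted=("p", "mean"), observed_rate=("y", "mean"))
              .reset_index())


# Hypothetical toy predictions
toy_table = reliability_table([0, 0, 1, 1], [0.05, 0.15, 0.85, 0.95], n_bins=2)
print(toy_table)
```

On the real predictions it could be called as `reliability_table(hybrid_predictions["Actual"], hybrid_predictions["Prob_Positive_Class"])`; with only about 100 deaths in the test set, fewer, wider bins keep the per-bin counts from becoming too small.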

Calibrated Model Metrics

             Model       AUC  ...  False Negative (FN)  True Positive (TP)
0  Hybrid_Catboost  0.735938  ...                   27                  75

[1 rows x 13 columns]

Final Feature Importances

library(dplyr)
library(tibble)

# Retrieve the final SHAP importance dataframe from Python
final_shap_df <- py$final_shap_df

# Convert the pandas dataframe to an R tibble sorted by importance
final_shap_tbl <- as_tibble(final_shap_df) %>%
  arrange(desc(Importance))

# Add the overall rank and keep the columns to display
final_shap_table <- final_shap_tbl %>%
  rowid_to_column(var = "Overall Rank") %>%
  select(`Overall Rank`, `Feature Id`, Importance)

final_shap_table %>% 
  DT::datatable(
    rownames = FALSE,
    options = list(
      columnDefs = list(
        list(className = 'dt-center', targets = "_all")
      )
    )
  )
Cluster Function for Top Features

# Assign each feature to a k-means cluster based on its importance
final_clustering <- function(data, optimal_k) {
  # k-means with 25 random starts; the best solution is kept
  kmeans_res <- kmeans(as.matrix(data), centers = optimal_k, nstart = 25)
  return(kmeans_res$cluster)
}

optimal_k <- 3

# Make the random k-means starts reproducible
set.seed(1997)

# Perform clustering using the optimal number of clusters
final_shap_df <- final_shap_df %>%
  mutate(Cluster = final_clustering(select(., Importance), optimal_k))

# View the clustered data
final_shap_df %>% 
  DT::datatable(
    rownames = FALSE,
    options = list(
      columnDefs = list(
        list(className = 'dt-center', targets = "_all")
      )
    )
  )

Based on the WCSS (within-cluster sum of squares) clustering, ‘ECMO_Reg’ is the cutoff feature. Because several variables are correlated, however, we include a few additional features that may help explain ‘Med_Refusals’. We therefore set the final cutoff at ‘Txp_Volume’, which corresponds to a feature importance of 0.02.
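The clustering code fixes `optimal_k <- 3`, while the cutoff discussion refers to WCSS. An elbow curve of WCSS against k is one common way to make that choice. A minimal NumPy sketch using Lloyd's algorithm with quantile-based starting centroids (R's `kmeans()` defaults to the Hartigan-Wong algorithm with random starts, so values can differ slightly); the function names and toy importances are hypothetical.

```python
import numpy as np


def kmeans_1d(values, k: int, n_iter: int = 100):
    """Lloyd's k-means on 1-D data; returns (centers, WCSS)."""
    x = np.asarray(values, dtype=float)
    # Deterministic start: centroids at evenly spaced quantiles
    centers = np.quantile(x, (np.arange(k) + 0.5) / k)
    for _ in range(n_iter):
        labels = np.argmin(np.abs(x[:, None] - centers[None, :]), axis=1)
        new_centers = np.array([x[labels == j].mean() if np.any(labels == j) else centers[j]
                                for j in range(k)])
        if np.allclose(new_centers, centers):
            break
        centers = new_centers
    labels = np.argmin(np.abs(x[:, None] - centers[None, :]), axis=1)
    wcss = float(((x - centers[labels]) ** 2).sum())
    return centers, wcss


def wcss_curve(values, max_k: int = 6) -> list:
    """WCSS for k = 1..max_k; look for the 'elbow' where gains flatten."""
    return [kmeans_1d(values, k)[1] for k in range(1, max_k + 1)]


# Hypothetical mean |SHAP| importances with three visible groups
importances = [0.30, 0.28, 0.10, 0.09, 0.11, 0.01, 0.012, 0.008]
print(wcss_curve(importances, max_k=4))
```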

CatBoost Model Accuracy Summary

model_accuracy

| Model | AUC | Brier Score | Accuracy | Log Loss | F1 Score | Precision | Recall | AUPR |
|---|---|---|---|---|---|---|---|---|
| Native_Catboost | 0.708 | 0.0874 | 0.672 | 0.308 | 0.290 | 0.186 | 0.657 | 0.222 |
| Hybrid_Catboost | 0.736 | 0.0839 | 0.618 | 0.294 | 0.282 | 0.174 | 0.735 | 0.294 |

model_confusion_matrix

| Model | True Negative (TN) | False Positive (FP) | False Negative (FN) | True Positive (TP) |
|---|---|---|---|---|
| Native_Catboost | 605 | 293 | 35 | 67 |
| Hybrid_Catboost | 543 | 355 | 27 | 75 |
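The threshold-dependent columns of the accuracy summary follow directly from the confusion matrix, whereas AUC, Brier score, log loss and AUPR need the predicted probabilities. A short check in Python (the helper name `confusion_metrics` is hypothetical):

```python
def confusion_metrics(tn: int, fp: int, fn: int, tp: int) -> dict:
    """Threshold-based metrics derived from a binary confusion matrix."""
    n = tn + fp + fn + tp
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)  # sensitivity
    return {
        "accuracy": (tp + tn) / n,
        "precision": precision,
        "recall": recall,
        "f1": 2 * precision * recall / (precision + recall),
    }


matrices = {
    "Native_Catboost": (605, 293, 35, 67),
    "Hybrid_Catboost": (543, 355, 27, 75),
}
for name, cm in matrices.items():
    print(name, {k: round(v, 3) for k, v in confusion_metrics(*cm).items()})
```

Both models are scored on the same 1,000 test patients (898 survivors, 102 deaths), and the rounded values match the Accuracy, Precision, Recall and F1 Score columns above. The hybrid model trades precision for recall: it flags 430 patients instead of 360 and identifies 8 more deaths.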

SHAP Value Analysis

# Make feature names plot-friendly by replacing '_' with ' '
# (shap_values is the SHAP values object created earlier)
feature_names = shap_values.feature_names
shap_values.feature_names = [name.replace('_', ' ') for name in feature_names]

Mean Absolute Value Feature Importance

library(ggplot2)
library(plotly)
library(dplyr)


# ggplot bar chart object
p <- ggplot(final_shap_df, aes(x = reorder(`Feature Id`, -Importance), y = Importance, fill = `Feature Id`)) +
  geom_bar(stat = "identity") +
  geom_hline(yintercept = 0.02, linetype = "dashed", color = "red") +  # Add cutoff line
  annotate("text", x = 33.5, y = 0.025, label = "Cutoff@0.02 (Txp_Volume)", color = "red", hjust = 0) +  # Add annotation
  labs(title = "Sorted Mean Absolute SHAP Values", x = "Features", y = "Mean Absolute SHAP Value") +
  theme_minimal() +
  theme(axis.text.x = element_text(angle = 45, hjust = 1), legend.position = "none")  # Adjust text angle for better readability

# Interactive plotly object
p_interactive <- ggplotly(p, tooltip = c("x", "y"))  # Hover effects with tooltips for both feature name and value

# Display the interactive plot
p_interactive
# Save the interactive plot as HTML
htmlwidgets::saveWidget(p_interactive, "sorted_shap_values_interactive.html")

Beeswarm (Top Features - Categorical and Numerical)

Figure 1: Beeswarm Chart

Bar Chart for Feature Importance

Figure 2: Feature Importance Chart

Beeswarm Top Numerical Features

Figure 3: Numerical Beeswarm Chart

Feature Importance Correlation Plot

Figure 4: SHAP Value Correlation Plot

SHAP Partial Dependence Plots
